{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# COVID-19 Drug Repurposing via gene-compounds relations\n",
    "This example shows how to do drug repurposing using DRKG even with the pretrained model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collecting COVID-19 related disease\n",
    "At the very beginning we need to collect a list of associated genes for Corona-Virus(COV) in DRKG. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "442\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "file='coronavirus-related-host-genes.tsv'\n",
    "df = pd.read_csv(file, sep=\"\\t\")\n",
    "cov_genes = np.unique(df.values[:,2]).tolist()\n",
    "file='covid19-host-genes.tsv'\n",
    "df = pd.read_csv(file, sep=\"\\t\")\n",
    "cov2_genes = np.unique(df.values[:,2]).tolist()\n",
    "# keep unique related genes\n",
    "\n",
    "cov_related_genes=list(set(cov_genes+cov2_genes))\n",
    "#cov_related_genes=list(set(cov2_genes))\n",
    "print(len(cov_related_genes))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Candidate drugs\n",
    "Now we use FDA-approved drugs in Drugbank as candidate drugs. (we exclude drugs with molecule weight < 250) The drug list is in infer\\_drug.tsv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "# Load entity file\n",
    "drug_list = []\n",
    "with open(\"./infer_drug.tsv\", newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['drug','ids'])\n",
    "    for row_val in reader:\n",
    "        drug_list.append(row_val['drug'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8104"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(drug_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inhibits relation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One inhibit relation in this context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "treatment = ['GNBR::N::Compound:Gene']#'DRUGBANK::target::Compound:Gene','DGIDB::INHIBITOR::Gene:Compound']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get pretrained model\n",
    "We can directly use the pretrianed model to do drug repurposing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import sys\n",
    "import csv\n",
    "sys.path.insert(1, '../utils')\n",
    "from utils import download_and_extract\n",
    "download_and_extract()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_idmap_file = '../data/drkg/embed/entities.tsv'\n",
    "relation_idmap_file = '../data/drkg/embed/relations.tsv'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get embeddings for genes and drugs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get drugname/disease name to entity ID mappings\n",
    "entity_map = {}\n",
    "entity_id_map = {}\n",
    "relation_map = {}\n",
    "with open(entity_idmap_file, newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['name','id'])\n",
    "    for row_val in reader:\n",
    "        entity_map[row_val['name']] = int(row_val['id'])\n",
    "        entity_id_map[int(row_val['id'])] = row_val['name']\n",
    "        \n",
    "with open(relation_idmap_file, newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['name','id'])\n",
    "    for row_val in reader:\n",
    "        relation_map[row_val['name']] = int(row_val['id'])\n",
    "        \n",
    "# handle the ID mapping\n",
    "drug_ids = []\n",
    "gene_ids = []\n",
    "for drug in drug_list:\n",
    "    drug_ids.append(entity_map[drug])\n",
    "    \n",
    "for gene in cov_related_genes:\n",
    "    gene_ids.append(entity_map[gene])\n",
    "\n",
    "treatment_rid = [relation_map[treat]  for treat in treatment]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load embeddings\n",
    "import torch as th\n",
    "entity_emb = np.load('../data/drkg/embed/DRKG_TransE_l2_entity.npy')\n",
    "rel_emb = np.load('../data/drkg/embed/DRKG_TransE_l2_relation.npy')\n",
    "\n",
    "drug_ids = th.tensor(drug_ids).long()\n",
    "gene_ids = th.tensor(gene_ids).long()\n",
    "treatment_rid = th.tensor(treatment_rid)\n",
    "\n",
    "drug_emb = th.tensor(entity_emb[drug_ids])\n",
    "treatment_embs = [th.tensor(rel_emb[rid]) for rid in treatment_rid]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Drug Repurposing Based on Edge Score\n",
    "We use following algorithm to calculate the edge score. Note, here we use logsigmiod to make all scores < 0. The larger the score is, the stronger the $h$ will have $r$ with $t$.\n",
    "\n",
    "$\\mathbf{d} = \\gamma - ||\\mathbf{h}+\\mathbf{r}-\\mathbf{t}||_{2}$\n",
    "\n",
    "$\\mathbf{score} = \\log\\left(\\frac{1}{1+\\exp(\\mathbf{-d})}\\right)$\n",
    "\n",
    "When doing drug repurposing, we only use the treatment related relations."
   ]
  },
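  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the formula, here is a worked example. The distances 10 and 14 are made-up values chosen for easy arithmetic, not numbers taken from DRKG; $\\gamma = 12$ matches the value used in the code below.\n",
    "\n",
    "$||\\mathbf{h}+\\mathbf{r}-\\mathbf{t}||_{2} = 10: \\quad d = 12 - 10 = 2, \\quad \\log\\left(\\frac{1}{1+\\exp(-2)}\\right) \\approx -0.127$\n",
    "\n",
    "$||\\mathbf{h}+\\mathbf{r}-\\mathbf{t}||_{2} = 14: \\quad d = 12 - 14 = -2, \\quad \\log\\left(\\frac{1}{1+\\exp(2)}\\right) \\approx -2.127$\n",
    "\n",
    "Both scores are negative, and the triple whose embeddings lie closer together gets the score nearer to 0, so it ranks higher."
   ]
  },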
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as fn\n",
    "\n",
    "gamma=12.0\n",
    "def transE_l2(head, rel, tail):\n",
    "    score = head + rel - tail\n",
    "    return gamma - th.norm(score, p=2, dim=-1)\n",
    "\n",
    "scores_per_gene = []\n",
    "dids_per_gene = []\n",
    "for rid in range(len(treatment_embs)):\n",
    "    treatment_emb=treatment_embs[rid]\n",
    "    for gene_id in gene_ids:\n",
    "        gene_emb = th.tensor(entity_emb[gene_id])\n",
    "        if treatment[rid]=='DGIDB::INHIBITOR::Gene:Compound':\n",
    "            score = fn.logsigmoid(transE_l2(gene_emb, treatment_emb,\n",
    "                                        drug_emb))\n",
    "        else:\n",
    "            score = fn.logsigmoid(transE_l2(drug_emb, treatment_emb,\n",
    "                                            gene_emb))\n",
    "        scores_per_gene.append(score)\n",
    "        dids_per_gene.append(drug_ids)\n",
    "scores = th.cat(scores_per_gene)\n",
    "dids = th.cat(dids_per_gene)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check clinical trial drugs per gene\n",
    "Here we load the clinical trial drugs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "clinical_drugs_file = './COVID19_clinical_trial_drugs.tsv'\n",
    "clinical_drug_map = {}\n",
    "with open(clinical_drugs_file, newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['id', 'drug_name','drug_id'])\n",
    "    for row_val in reader:\n",
    "        clinical_drug_map[row_val['drug_id']] = row_val['drug_name']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we measure some statistics per gene."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Gene::6441\t9\t\n",
      "[0]Dexamethasone\n",
      "[29]Methylprednisolone\n",
      "[30]Ribavirin\n",
      "[40]Thalidomide\n",
      "[46]Chloroquine\n",
      "[77]Losartan\n",
      "[86]Sargramostim\n",
      "[88]Azithromycin\n",
      "[90]Hydroxychloroquine\n",
      "\n",
      "DB01234\tDexamethasone\t401\t17.424322932617844\n",
      "DB01041\tThalidomide\t336\t9.52602832899466\n",
      "DB00608\tChloroquine\t258\t5.281556104219857\n",
      "DB00746\tDeferoxamine\t111\t2.3803197362314727\n",
      "DB01394\tColchicine\t108\t1.9397152439066307\n",
      "DB00959\tMethylprednisolone\t105\t1.6800667504790185\n",
      "DB00678\tLosartan\t92\t1.9905909204249115\n",
      "DB00811\tRibavirin\t92\t2.0302922908647756\n",
      "DB08877\tRuxolitinib\t47\t0.7744534092963637\n",
      "DB08895\tTofacitinib\t33\t0.46233716095307054\n",
      "DB01611\tHydroxychloroquine\t14\t0.20167135495496702\n",
      "DB05511\tPiclidenoson\t6\t0.1513038675225646\n",
      "DB00207\tAzithromycin\t5\t0.05829492617697397\n",
      "DB00198\tOseltamivir\t1\t0.2\n",
      "DB00020\tSargramostim\t1\t0.011494252873563218\n"
     ]
    }
   ],
   "source": [
    "maxhit=0\n",
    "drugs_in_top_k={}\n",
    "drugsfr_in_top_k={}\n",
    "for i in range(len(scores_per_gene)):\n",
    "    score=scores_per_gene[i]\n",
    "    did=dids_per_gene[i]\n",
    "    idx = th.flip(th.argsort(score), dims=[0])\n",
    "    score = score[idx].numpy()\n",
    "    did = did[idx].numpy()\n",
    "    #print(did)\n",
    "    _, unique_indices = np.unique(did, return_index=True)\n",
    "    topk=100\n",
    "    topk_indices = np.sort(unique_indices)[:topk]\n",
    "    proposed_did = did[topk_indices]\n",
    "    proposed_score = score[topk_indices]\n",
    "    found_in_top_k=0\n",
    "    found_drugs=\"\\n\"\n",
    "    for j in range(topk):\n",
    "        drug = entity_id_map[int(proposed_did[j])][10:17]\n",
    "        if clinical_drug_map.get(drug, None) is not None:\n",
    "            found_in_top_k+=1\n",
    "            score = proposed_score[j]\n",
    "            if drug in drugs_in_top_k:\n",
    "                drugs_in_top_k[drug]+=1\n",
    "                drugsfr_in_top_k[drug]+=1/(j+1)\n",
    "            else:\n",
    "                drugs_in_top_k[drug]=1\n",
    "                drugsfr_in_top_k[drug]=1/(j+1)\n",
    "            found_drugs+=\"[{}]{}\\n\".format(j, clinical_drug_map[drug])\n",
    "            #print(\"[{}]{}\".format(j, clinical_drug_map[drug]))\n",
    "    #print(\"{}\\t{}\".format(cov_related_genes[i], found_in_top_k))\n",
    "    if maxhit< found_in_top_k:\n",
    "        maxhit=found_in_top_k\n",
    "        maxgene=cov_related_genes[i]\n",
    "        max_dugs=found_drugs\n",
    "print(\"{}\\t{}\\t{}\".format(maxgene, maxhit,max_dugs))\n",
    "\n",
    "res=[[drug, clinical_drug_map[drug] ,drugs_in_top_k[drug],drugsfr_in_top_k[drug]] for drug in drugs_in_top_k.keys()]\n",
    "res=reversed(sorted(res, key=lambda x : x[2]))\n",
    "for drug in res:\n",
    "    print(\"{}\\t{}\\t{}\\t{}\".format(drug[0], drug[1] ,drug[2],drug[3]))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
